#!/usr/bin/env python3
"""
Script to synchronize (local>remote and vice versa) test data files from/to GCS.

//test/data files are not checked into the codebase because they are large
binary files and change frequently. Instead we check in only xxx.sha256 files,
which contain the SHA-256 of the actual binary file, and sync them from a GCS
bucket.

Files in the GCS bucket are content-indexed as gs://bucket/file_name-a1b2c3f4 .

Usage:
./test_data status     # Prints the status of new & modified files.
./test_data download   # To sync remote>local (used by install-build-deps).
./test_data upload     # To upload newly created and modified files.
./test_data clean      # To delete local files that are new or modified.
"""

import argparse
import logging
import os
import sys
import hashlib
import subprocess
import random
import time

from multiprocessing.pool import ThreadPool
from collections import namedtuple, defaultdict

# Directory two levels above this script's absolute path. It is the base for
# the default --dir (ROOT_DIR/test/data) and for the paths printed by relpath().
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Default value of the --bucket flag.
BUCKET = 'gs://perfetto/test_data'
# Extension of the tracker file that holds the expected digest of a data file
# (the tracker of "foo" is "foo.sha256").
SUFFIX = '.sha256'

# Values of FileStat.status, as computed by get_file_status(). The `status`
# command prints these strings verbatim.
# Data file and tracker both exist and the digests are equal.
FS_MATCH = 'matches'
# Data file exists but has no tracker.
FS_NEW_FILE = 'needs upload'
# Data file and tracker both exist but the digests differ.
FS_MODIFIED = 'modified'
# Tracker exists but the data file does not.
FS_MISSING = 'needs download'

# Result of get_file_status():
#   path:            path of the data file (without the .sha256 suffix).
#   status:          one of the FS_* constants above.
#   actual_digest:   SHA-256 hex digest of the data file, None if it is absent.
#   expected_digest: digest read from the tracker file, None if it is absent.
FileStat = namedtuple('FileStat',
                      ['path', 'status', 'actual_digest', 'expected_digest'])
# Parsed command line arguments. Set by main() and read by the other functions.
args = None


def relpath(path):
  """Returns `path` expressed relative to ROOT_DIR."""
  relative = os.path.relpath(path, start=ROOT_DIR)
  return relative


def download(url, out_file):
  """Downloads `url` into `out_file` by invoking curl, retrying on failure.

  Up to MAX_RETRIES attempts are made. Between attempts the function sleeps for
  a random time that grows with the attempt number.

  Args:
    url: the URL to fetch.
    out_file: path of the file the response body is written to.

  Raises:
    subprocess.CalledProcessError: if curl still fails on the last attempt.
  """
  MAX_RETRIES = 3
  # --fail makes curl exit with a non-zero status on HTTP errors (>= 400).
  # Without it an HTTP error counts as a success: the error page is saved as
  # the output file and the retry logic below never kicks in.
  cmd = ['curl', '--fail', '-L', '-s', '-o', out_file, url]
  for attempt in range(1, MAX_RETRIES + 1):
    try:
      subprocess.check_call(cmd)
      return  # success
    except subprocess.CalledProcessError:
      if attempt == MAX_RETRIES:
        raise
      # Randomized backoff, so that concurrent downloads don't retry in
      # lockstep.
      sleep_time = random.random() * 5 * attempt
      print(f"{' '.join(cmd)} failed, retrying in {round(sleep_time)} sec")
      time.sleep(sleep_time)


def list_files(path, scan_new_files=False):
  """ Yields the data files found by walking `path` recursively.

  A xxx.sha256 tracker is reported as the data file it tracks (xxx), whether or
  not xxx exists on disk. Each path is yielded at most once.

  If scan_new_files=False, yields only files with a matching xxx.sha256 tracker.
  If scan_new_files=True, yields all files, including untracked ones.

  The following files are always skipped: *.swp (temporary files left around
  if CTRL-C-ing while downloading), OWNERS and README.md (they should not be
  uploaded) and files whose name starts with a dot.
  """
  never_listed = ('OWNERS', 'README.md')
  emitted = set()
  for dirpath, _subdirs, filenames in os.walk(path):
    for name in filenames:
      if name.endswith('.swp') or name in never_listed:
        continue
      if name.startswith('.'):
        continue
      full_path = os.path.join(dirpath, name)
      if not os.path.isfile(full_path):
        continue
      is_tracker = full_path.endswith(SUFFIX)
      if not (is_tracker or scan_new_files):
        continue
      # Trackers stand for the data file they describe.
      data_path = full_path[:-len(SUFFIX)] if is_tracker else full_path
      if data_path in emitted:
        continue
      emitted.add(data_path)
      yield data_path


def hash_file(fpath):
  """Returns the hex-encoded SHA-256 digest of the content of `fpath`.

  The file is read in fixed-size chunks, so it is never loaded in memory as a
  whole.
  """
  CHUNK_SIZE = 32768
  digest = hashlib.sha256()
  with open(fpath, 'rb') as stream:
    while True:
      block = stream.read(CHUNK_SIZE)
      if not block:
        break
      digest.update(block)
  return digest.hexdigest()


def map_concurrently(fn, files):
  """Applies `fn` to each item of `files` on a pool of args.jobs threads.

  Unless --quiet is set, prints a one-line progress indicator which is
  rewritten in place as results come in.

  Args:
    fn: callable that takes one item of `files` and returns a FileStat.
    files: sized sequence of items; its length is the progress total.
  """
  total = len(files)
  done = 0
  # The context manager tears the pool down on exit, so the worker threads
  # are not leaked, including when `fn` raises.
  with ThreadPool(args.jobs) as pool:
    for fs in pool.imap_unordered(fn, files):
      assert (isinstance(fs, FileStat))
      done += 1
      if not args.quiet:
        # Only the last 60 characters of the path are shown, to keep the
        # line short. '\r' makes the next update overwrite this one.
        print(
            '[%d/%d] %-60s' % (done, total, relpath(fs.path)[-60:]), end='\r')
  if not args.quiet:
    print('')


def get_file_status(fpath):
  """Compares the data file `fpath` against its `fpath`.sha256 tracker.

  Returns a FileStat whose status is:
    FS_MISSING:  the tracker exists but the data file does not.
    FS_NEW_FILE: the data file exists but the tracker does not.
    FS_MATCH:    both exist and the digests are equal.
    FS_MODIFIED: both exist and the digests differ.

  Raises an Exception carrying `fpath` if neither file exists.
  """
  tracker_path = fpath + SUFFIX
  has_tracker = os.path.exists(tracker_path)
  has_file = os.path.exists(fpath)

  # The expected digest is the first line of the tracker file.
  expected_digest = None
  if has_tracker:
    with open(tracker_path, 'r') as tracker:
      expected_digest = tracker.readline().strip()
  actual_digest = hash_file(fpath) if has_file else None

  if not has_tracker and not has_file:
    raise Exception(fpath)
  if not has_file:
    status = FS_MISSING
  elif not has_tracker:
    status = FS_NEW_FILE
  elif expected_digest == actual_digest:
    status = FS_MATCH
  else:
    status = FS_MODIFIED
  return FileStat(fpath, status, actual_digest, expected_digest)


def cmd_upload(dir):
  """Uploads new and locally modified files in `dir` to the bucket.

  Each file is copied with gsutil to <bucket>/<basename>-<sha256>, then its
  .sha256 tracker is rewritten with the new digest. Unless --quiet is set, the
  files are listed and a confirmation is requested before uploading. With
  --dry-run the files are only listed and nothing is uploaded.

  Args:
    dir: directory to scan recursively.

  Returns:
    0, the process exit code.

  Raises:
    subprocess.CalledProcessError: if gsutil fails.
  """
  all_files = list_files(dir, scan_new_files=True)
  # The context manager tears the pool down once the scan is complete.
  with ThreadPool(args.jobs) as pool:
    files_to_upload = [
        fs for fs in pool.imap_unordered(get_file_status, all_files)
        if fs.status in (FS_NEW_FILE, FS_MODIFIED)
    ]
  if not files_to_upload:
    if not args.quiet:
      print('No modified or new files require uploading')
    return 0
  if not args.quiet:
    print('\n'.join(relpath(f.path) for f in files_to_upload))
    print('')
  # A dry run stops here, after having shown what would be uploaded.
  if args.dry_run:
    return 0
  if not args.quiet:
    print('About to upload %d files' % len(files_to_upload))
    input('Press a key to continue or CTRL-C to abort')

  def upload_one_file(fs):
    assert (fs.actual_digest is not None)
    dst_name = '%s/%s-%s' % (args.bucket, os.path.basename(
        fs.path), fs.actual_digest)
    cmd = ['gsutil', '-q', 'cp', '-n', '-a', 'public-read', fs.path, dst_name]
    logging.debug(' '.join(cmd))
    subprocess.check_call(cmd)
    # Write the tracker to a temporary file and rename it over the real one,
    # so that an interrupted run never leaves a partially written tracker.
    tmp_tracker = fs.path + SUFFIX + '.swp'
    with open(tmp_tracker, 'w') as f:
      f.write(fs.actual_digest)
    os.replace(tmp_tracker, fs.path + SUFFIX)
    return fs

  map_concurrently(upload_one_file, files_to_upload)
  return 0


def cmd_clean(dir):
  """Deletes the local files in `dir` that are new or locally modified.

  Only the data files are removed; .sha256 trackers are left untouched. Unless
  --quiet is set, the files are listed and a confirmation is requested before
  deleting. With --dry-run the files are only listed and nothing is deleted.

  Args:
    dir: directory to scan recursively.

  Returns:
    0, the process exit code.
  """
  all_files = list_files(dir, scan_new_files=True)
  # The context manager tears the pool down once the scan is complete.
  with ThreadPool(args.jobs) as pool:
    files_to_clean = [
        fs.path
        for fs in pool.imap_unordered(get_file_status, all_files)
        if fs.status in (FS_NEW_FILE, FS_MODIFIED)
    ]
  if not files_to_clean:
    if not args.quiet:
      print('No modified or new files require cleaning')
    return 0
  if not args.quiet:
    print('\n'.join(relpath(f) for f in files_to_clean))
    print('')
  # A dry run stops here, after having shown what would be removed.
  if args.dry_run:
    return 0
  if not args.quiet:
    print('About to remove %d files' % len(files_to_clean))
    input('Press a key to continue or CTRL-C to abort')
  for path in files_to_clean:
    os.remove(path)
  return 0


def cmd_download(dir, overwrite_locally_modified=False):
  """Downloads the tracked files in `dir` that are missing locally.

  Each file is fetched over HTTPS from <bucket>/<basename>-<sha256> into a
  temporary .swp file, checked against the digest in its tracker and only then
  moved in place. With --dry-run the files are only listed.

  Args:
    dir: directory to scan recursively for .sha256 trackers.
    overwrite_locally_modified: if True, files whose content differs from the
      tracked digest are downloaded again. If False they are left untouched
      and the user is asked to confirm before continuing.

  Returns:
    0, the process exit code.

  Raises:
    Exception: if the digest of a downloaded file does not match its tracker.
  """
  files_to_download = []
  modified = []
  all_files = list_files(dir, scan_new_files=False)
  # The context manager tears the pool down once the scan is complete.
  with ThreadPool(args.jobs) as pool:
    for fs in pool.imap_unordered(get_file_status, all_files):
      if fs.status == FS_MISSING:
        files_to_download.append(fs)
      elif fs.status == FS_MODIFIED:
        modified.append(fs)

  if modified and not overwrite_locally_modified:
    print('WARNING: The following files diverged locally and will NOT be ' +
          'overwritten if you continue')
    print('\n'.join(relpath(f.path) for f in modified))
    print('')
    print('Re run `download --overwrite` to overwrite locally modified files')
    print('or `upload` to sync them on the GCS bucket')
    print('')
    input('Press a key to continue or CTRL-C to abort')
  elif overwrite_locally_modified:
    files_to_download += modified

  if not files_to_download:
    if not args.quiet:
      print('Nothing to do, all files are synced')
    return 0

  if not args.quiet:
    print('Downloading %d files in //%s' %
          (len(files_to_download), relpath(dir)))
  if args.dry_run:
    # files_to_download holds FileStat tuples: join their paths, not the
    # tuples themselves (str.join() raises TypeError on non-str items).
    print('\n'.join(relpath(f.path) for f in files_to_download))
    return 0

  def download_one_file(fs):
    assert (fs.expected_digest is not None)
    uri = '%s/%s-%s' % (args.bucket, os.path.basename(
        fs.path), fs.expected_digest)
    # Fetch through the HTTPS endpoint rather than the gs:// URI.
    uri = uri.replace('gs://', 'https://storage.googleapis.com/')
    logging.debug(uri)
    # Download to a temporary file, so that a partial or corrupted download
    # never replaces the local file.
    tmp_path = fs.path + '.swp'
    download(uri, tmp_path)
    digest = hash_file(tmp_path)
    if digest != fs.expected_digest:
      raise Exception('Mismatching digest for %s. expected=%s, actual=%s' %
                      (uri, fs.expected_digest, digest))
    os.replace(tmp_path, fs.path)
    return fs

  map_concurrently(download_one_file, files_to_download)
  return 0


def cmd_status(dir):
  """Prints the files in `dir` that are out of sync with their tracker.

  One line is printed per out-of-sync file, grouped by status and sorted by
  path within each status. With --ignore-new, untracked files are neither
  printed nor counted. --quiet suppresses all the output.

  Args:
    dir: directory to scan recursively.

  Returns:
    0 if every file is in sync, 1 otherwise (used as the process exit code).
  """
  files = list_files(dir, scan_new_files=True)
  file_by_status = defaultdict(list)
  num_files = 0
  # The context manager tears the pool down once the scan is complete.
  with ThreadPool(args.jobs) as pool:
    for fs in pool.imap_unordered(get_file_status, files):
      file_by_status[fs.status].append(relpath(fs.path))
      num_files += 1

  num_out_of_sync = 0
  for status, rpaths in sorted(file_by_status.items()):
    if status == FS_MATCH:
      continue
    if status == FS_NEW_FILE and args.ignore_new:
      continue
    num_out_of_sync += len(rpaths)
    if not args.quiet:
      # Results arrive in thread completion order: sort them to get the same
      # output on every run.
      for rpath in sorted(rpaths):
        print('%-15s: %s' % (status, rpath))
  if num_out_of_sync == 0:
    if not args.quiet:
      print('Scanned %d files in //%s, everything in sync.' %
            (num_files, relpath(dir)))
    return 0
  return 1


def main():
  """Parses the command line, configures logging and runs the chosen command.

  The parsed arguments are stored in the module-level `args`, which the
  command functions read.

  Returns:
    The return value of the command function, used as the process exit code.
  """
  global args
  parser = argparse.ArgumentParser()
  parser.add_argument('--dir', default=os.path.join(ROOT_DIR, 'test/data'))
  parser.add_argument('--overwrite', action='store_true')
  parser.add_argument('--bucket', default=BUCKET)
  parser.add_argument('--jobs', '-j', default=8, type=int)
  parser.add_argument('--dry-run', '-n', action='store_true')
  parser.add_argument('--quiet', '-q', action='store_true')
  parser.add_argument('--verbose', '-v', action='store_true')
  parser.add_argument('--ignore-new', action='store_true')
  parser.add_argument('cmd', choices=['status', 'download', 'upload', 'clean'])
  args = parser.parse_args()

  log_level = logging.INFO
  if args.verbose:
    log_level = logging.DEBUG
  logging.basicConfig(
      format='%(asctime)s %(levelname).1s %(message)s',
      level=log_level,
      datefmt=r'%H:%M:%S')

  # Maps each command name to the call that implements it.
  commands = {
      'status':
          lambda: cmd_status(args.dir),
      'download':
          lambda: cmd_download(
              args.dir, overwrite_locally_modified=args.overwrite),
      'upload':
          lambda: cmd_upload(args.dir),
      'clean':
          lambda: cmd_clean(args.dir),
  }
  run_command = commands.get(args.cmd)
  if run_command is None:
    print('Unknown command: %s' % args.cmd)
    return None
  return run_command()


if __name__ == '__main__':
  sys.exit(main())
